import numpy as np
import cv2
from matplotlib import pyplot as plt


def conv(image, kernel, mode='same'):

    if mode == 'fill':  #选择是否进行边缘填充
        h = kernel.shape[0] // 2   #卷积核的列整除2
        w = kernel.shape[1] // 2   #卷积核的行整除2
        #在原始图像边缘进行填充，常数填充，填数值0
        image = np.pad(image, ((h, h), (w, w), (0, 0)), 'constant')

    #进行卷积运算
    conv_b = convolve(image[:, :, 0], kernel)
    conv_g = convolve(image[:, :, 1], kernel)
    conv_r = convolve(image[:, :, 2], kernel)
    res = np.dstack([conv_b, conv_g, conv_r])
    return res


def convolve(image, kernel):
    h_kernel, w_kernel = kernel.shape  #获取卷积核的长宽，也就是行数和列数

    h_image, w_image = image.shape   #获取欲处理图片的长宽

    #计算卷积核中心点开始运动的点
    res_h = h_image - h_kernel + 1
    res_w = w_image - w_kernel + 1

    #生成一个numpy数组，用于保存处理后的图片
    res = np.zeros((res_h, res_w), np.uint8)

    for i in range(res_h):
        for j in range(res_w):
            #image处传入的是一个与卷积核一样大小矩阵，这个矩阵取自于欲处理图片的一部分
            #这个矩阵与卷核进行运算，用i与j来进行卷积核滑动
            res[i, j] = getMultiplyAns(image[i:i + h_kernel, j:j + w_kernel], kernel)

    return res

#两个数组(矩阵)，点对点相乘后进行累加
def getMultiplyAns(image, kernel):
    res = np.multiply(image, kernel).sum()
    if res > 255:
        return 255
    elif res<0:
        return 0
    else:
        return res

#测试函数
def testCovolve():
    image = cv2.imread("lena_noise.bmp")
    k1 = np.array([
        [1 / 9, 1 / 9, 1 / 9],
        [1 / 9, 1 / 9, 1 / 9],
        [1 / 9, 1 / 9, 1 / 9]
    ])
    k2 = np.array([[-1, 0, 1],
                   [-2, 0, 2],
                   [-1, 0, 1]])
    res1 = conv(image, k1, 'fill')
    cv2.imshow("Convoluted picture by k1", res1)
    res2 = conv(image, k2, 'fill')
    cv2.imshow("Convoluted picture by k2", res2)
    cv2.imshow('yuan shi tu xiang', image)
    cv2.waitKey(-1)


if __name__ == '__main__':
    testCovolve()